# MIT License
#
# Copyright (c) 2023 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import os
import pickle

from nndct_shared.pruning import logging
from nndct_shared.pruning import pruning_lib
from nndct_shared.utils import io, logging
from typing import Mapping, Any, List
import json

class GroupMetrics(object):
  """Analysis results for one node group: node names plus metric samples."""

  def __init__(self, nodes, metrics, num_groups: int = 1):
    # nodes: names of the nodes that are pruned together as one group.
    # metrics: sequence of objects exposing `sparsity`/`value` attributes;
    #   entries may be None for analysis steps that never completed.
    # num_groups: channel-group count carried through to pruning.
    self.nodes = nodes
    self.metrics = metrics
    self.num_groups = num_groups

  def __repr__(self):
    node_lines = '\n'.join('    ' + name for name in self.nodes)
    metric_lines = '\n'.join(
        '    ({}, {})'.format(m.sparsity, m.value)
        for m in self.metrics if m)
    return ('GroupMetrics {\n'
            '  nodes {\n%s\n  }\n'
            '  metrics {\n%s\n  }\n'
            '}') % (node_lines, metric_lines)

  def serialize(self) -> Mapping[str, Any]:
    """Convert to a JSON-friendly dict; None metric entries are dropped."""
    kept = [m for m in self.metrics if m]
    return {
        "nodes": self.nodes,
        "metrics": [[m.sparsity, m.value] for m in kept],
        "num_groups": self.num_groups
    }

  @classmethod
  def deserialize(cls, data: Mapping[str, Any]):
    """Inverse of serialize(); a missing "num_groups" key defaults to 1."""
    metrics = [AnaMetric(pair[0], pair[1]) for pair in data["metrics"]]
    return cls(
        nodes=data["nodes"],
        metrics=metrics,
        num_groups=data.get("num_groups", 1))

class NetSensitivity(object):
  """The sensitivity results of the network generated by model analysis."""

  def __init__(self, groups: List[GroupMetrics] = [], graph_digest: str = None):
    self._groups = [g for g in groups]
    self._uncompleted_steps = []
    self._graph_digest: str = graph_digest

  def add_group(self, nodes, metrics, num_groups: int = 1):
    self._groups.append(GroupMetrics(nodes, metrics, num_groups))

  @property
  def uncompleted_steps(self):
    return self._uncompleted_steps

  @uncompleted_steps.setter
  def uncompleted_steps(self, uncompleted_steps):
    self._uncompleted_steps = uncompleted_steps

  @property
  def graph_digest(self) -> str:
    return self._graph_digest

  @graph_digest.setter
  def graph_digest(self, graph_digest: str) -> None:
    self._graph_digest = graph_digest

  @property
  def groups(self):
    return self._groups

  @groups.setter
  def groups(self, groups):
    self._groups = groups

  def prunable_groups_by_threshold(self, threshold, excludes=[]):
    prunable_groups = []
    for group in self._groups:
      skip = False
      for node in group.nodes:
        if node in excludes:
          skip = True
          break

      if skip:
        continue

      # Four common metric distributions:
      # w/o negative:
      # [120, 100, 54, ..., 0.14, 0.08]
      # [0.1, 0.2, 0.5, ..., 1.3, 1.8]
      #
      # with negative:
      # [0.85, 0.63, ..., 0.05, -0.01, -0.9]
      # [-40, -30, ..., 11, 50, 70]
      res = None
      baseline_value = group.metrics[0].value
      for sparsity, value in group.metrics:
        bias = baseline_value - value
        if abs(bias / (baseline_value + 1e-5)) > threshold:
          break
        res = sparsity

      if res is not None:
        prunable_groups.append(
            pruning_lib.PrunableGroup(group.nodes, res, group.num_groups))

    return prunable_groups

  def __repr__(self):
    strs = []
    uncompleted_steps = ','.join(str(i) for i in self._uncompleted_steps)
    strs.append('uncompleted_steps={%s}' % uncompleted_steps)
    for group in self._groups:
      strs.append(repr(group))
    return "\n".join(strs)

def load_sens(filepath):
  """Load a NetSensitivity from `filepath`.

  Tries the legacy pickle format first; on any failure falls back to the
  JSON format written by `save_sens`.
  """
  sens = NetSensitivity()
  try:
    # NOTE(review): pickle.load can execute arbitrary code — only load
    # sensitivity files from trusted sources.
    with open(filepath, 'rb') as fp:
      info = pickle.load(fp)
    sens.groups = info['groups']
    sens.uncompleted_steps = info['uncompleted_steps']
    sens.graph_digest = info.get("graph_digest")
  except Exception as e:
    logging.info(f"Reading sens from file {filepath} with pickle failed: {e}")
    logging.info(f"Now try reading sens from file {filepath} with json")
    with open(filepath, "r") as fp:
      info = json.load(fp)
    sens.groups = [GroupMetrics.deserialize(g) for g in info["groups"]]
    sens.uncompleted_steps = info["uncompleted_steps"]
    sens.graph_digest = info.get("graph_digest")
  return sens

def save_sens(sens: NetSensitivity, filepath: str):
  """Serialize `sens` to `filepath` as indented JSON."""
  payload = {
      'groups': [group.serialize() for group in sens.groups],
      'uncompleted_steps': sens.uncompleted_steps,
      'graph_digest': sens.graph_digest
  }
  with open(filepath, 'w') as fp:
    json.dump(payload, fp, indent=2)

def inspect_sens(filepath):
  """Load the sensitivity file at `filepath` and print its contents."""
  sens = load_sens(filepath)
  print(sens)

class AnaMetric(collections.namedtuple('AnaMetric', ['sparsity', 'value'])):
  """One analysis measurement: the evaluation `value` observed at `sparsity`."""

class ModelAnalyser(object):
  """Class for performing analysis on model.

  Runs a sparsity sweep over the prunable node groups of a graph: step 0
  evaluates the unpruned baseline, then each group gets `ExpsPerGroup`
  experiments at sparsities 0.1 .. 0.9. Per-step results are kept in a
  flat list and can be written to / restored from disk so an interrupted
  analysis can resume.
  """

  # Number of sparsity levels evaluated per group: 0.1, 0.2, ..., 0.9.
  ExpsPerGroup = 9

  def __init__(self, graph, excludes, with_group_conv: bool = False):
    """Build the analysis plan for `graph`.

    Args:
      graph: Network graph to analyze (passed to pruning_lib.group_nodes).
      excludes: Node names excluded from grouping.
      with_group_conv: Forwarded to pruning_lib.group_nodes; presumably
        controls whether grouped convolutions participate — confirm there.
    """
    # NOTE(review): _cur_step is initialized but never read or updated in
    # the code visible here.
    self._cur_step = 0

    groups = pruning_lib.group_nodes(graph, excludes, with_group_conv)
    # One shared baseline step (step 0) plus ExpsPerGroup steps per group.
    self._total_steps = len(groups) * ModelAnalyser.ExpsPerGroup + 1
    self._uncompleted_steps = [i for i in range(self._total_steps)]
    self._groups = groups
    # Flat per-step results; None until a step's result is recorded.
    self._metrics = [None for i in range(self._total_steps)]

  def recover_state(self, net_sens):
    """Restore per-step metrics from a previously saved NetSensitivity.

    Mirrors the layout written by `save()`: every saved group carries the
    shared baseline metric at index 0, followed by its own per-sparsity
    metrics.

    Args:
      net_sens: NetSensitivity loaded from disk (see load_sens).
    """
    self._uncompleted_steps = net_sens.uncompleted_steps
    step = 0
    # The first metric of the first group is the shared baseline (step 0).
    # The immediate double `break` makes this run at most once; if the
    # first group has no metrics, `step` stays 0.
    for group in net_sens.groups:
      for metric in group.metrics:
        self._metrics[step] = metric
        step += 1
        break
      break
    # The remaining metrics (index 1..) of every group map back, in order,
    # onto the flat per-step list.
    for i in range(len(net_sens.groups)):
      for j in range(1, len(net_sens.groups[i].metrics)):
        metric = net_sens.groups[i].metrics[j]
        self._metrics[step] = metric
        step += 1

  def uncompleted_steps(self):
    """Return the list of step indices not yet recorded."""
    return self._uncompleted_steps

  def steps(self):
    """Return the total number of analysis steps (baseline included)."""
    return self._total_steps

  def _eval_plan(self, step):
    """Map a flat step index to (group_index, sparsity).

    Returns (0, None) for the baseline step; otherwise the group index and
    one of the sparsities 0.1 .. 0.9 (computed as (exp + 1) * 0.1, so the
    value may carry float rounding noise, e.g. 0.30000000000000004).
    """
    # Step 0 for baseline.
    if step == 0:
      return 0, None
    group, exp = divmod(step - 1, ModelAnalyser.ExpsPerGroup)
    return group, (exp + 1) * 0.1

  def spec(self, step):
    """Build the PruningSpec to evaluate at `step`.

    The baseline step produces an empty spec; any other step produces a
    spec pruning exactly one group at the planned sparsity.
    """
    spec = pruning_lib.PruningSpec()
    if step > 0:
      group_idx, sparsity = self._eval_plan(step)
      group = self._groups[group_idx]
      spec.add_group(
          pruning_lib.PrunableGroup(group.nodes, sparsity, group.num_groups))
    # Empty spec for baseline.
    return spec

  @abc.abstractmethod
  def task(self):
    """Subclass hook for running one analysis experiment.

    NOTE(review): ModelAnalyser does not use ABCMeta/abc.ABC, so this
    @abc.abstractmethod does not actually prevent instantiation.
    """
    pass

  def record(self, step, result):
    """Store `result` for `step` and mark the step completed.

    Raises:
      IndexError: If `step` is >= the total number of steps.
      ValueError: If `step` was already completed (list.remove fails).
    """
    if step >= self._total_steps:
      raise IndexError

    logging.vlog(3, "Ana step {} record: {}".format(step, result))
    _, sparsity = self._eval_plan(step)
    self._metrics[step] = AnaMetric(sparsity, result)
    self._uncompleted_steps.remove(step)

  def save(self, filepath):
    """Write the current results to `filepath` via save_sens.

    Each group is saved with the shared baseline metric (step 0) prepended
    to its own ExpsPerGroup-sized slice of the flat metrics list — the
    inverse of `recover_state`.
    """
    net_sens = NetSensitivity()
    net_sens.uncompleted_steps = self._uncompleted_steps
    start_index = 1
    end_index = start_index + ModelAnalyser.ExpsPerGroup
    for group in self._groups:
      metrics = [self._metrics[0]] + self._metrics[start_index:end_index]
      net_sens.add_group(group.nodes, metrics, group.num_groups)
      start_index = end_index
      end_index += ModelAnalyser.ExpsPerGroup
    io.create_work_dir(os.path.dirname(filepath))
    save_sens(net_sens, filepath)
